Introduction to Reservoir Computing with ReservoirPy¶

Nathan Trouvain
Inria, IMN, LaBRI - Bordeaux, France




UCLA - November 14th 2023

Table of contents¶

  • Key concepts
  • Timeseries prediction
  • Learned attractors and generative capabilities
  • Online learning
  • Understanding hyperparameters
  • Use Case demo - falling robot
  • Use Case demo - canary songs transcriber

Key concepts and features¶

  • NumPy, SciPy, and that's it!
  • Efficient execution (distributed implementation)
  • Online and offline learning rules
  • Complex model architectures enabled
  • Documentation: https://reservoirpy.readthedocs.io/en/latest/
  • GitHub: https://github.com/reservoirpy/reservoirpy

General info¶

  • Everything is NumPy (and more generally "standard" scientific Python)
  • The first axis of arrays always represents time.
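This convention can be illustrated with a small sketch (plain NumPy, hypothetical random data):

```python
import numpy as np

# A 3-dimensional timeseries of 1000 timesteps:
# axis 0 is time, axis 1 holds the features.
X = np.random.normal(size=(1000, 3))

print(X.shape)     # (1000, 3)
print(X[0].shape)  # a single timestep -> (3,)
```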

Timeseries prediction¶

The Lorenz attractor

$$ \begin{split} \dot{x}(t) &= \sigma (y(t) - x(t)) \\ \dot{y}(t) &= \rho x(t) - y(t) - x(t)z(t) \\ \dot{z}(t) &= x(t)y(t) - \beta z(t) \end{split} $$

  • describes convection movements in a fluid. Highly chaotic!
In [4]:
from reservoirpy.datasets import lorenz

timesteps = 3000
X = lorenz(timesteps, x0=[17.67, 12.93, 43.91])
In [6]:
plot_lorenz(X, 1000)

Knowing the series value at timestep $t$:

  • How can we predict $t+1$, $t+100$...?
  • How can we predict $t+1$, $t+2$, $\dots$, $t+n$ ?

10-step-ahead forecasting¶

Predict $P(t + 10)$ knowing $P(t)$.

In [8]:
from reservoirpy.datasets import to_forecasting

x, y = to_forecasting(X, forecast=10)
X_train1, y_train1 = x[:2000], y[:2000]
X_test1, y_test1 = x[2000:], y[2000:]

plot_train_test(X_train1, y_train1, X_test1, y_test1)

A quick reservoir reminder¶


$$ x[t+1]= (1 - \alpha) x[t] + \alpha f(W \cdot x[t] + W_{in} \cdot u[t] + W_{fb} \cdot y[t]) $$ $$ y[t]= W_{out}^{\intercal} x[t] $$
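The leaky state update above (without the feedback term $W_{fb} \cdot y[t]$) can be sketched in a few lines of plain NumPy. The weights below are random placeholders, not the matrices ReservoirPy actually generates:

```python
import numpy as np

rng = np.random.default_rng(0)
N, dim = 100, 3                          # reservoir size, input dimension
alpha = 0.3                              # leak rate
W = rng.normal(scale=0.1, size=(N, N))   # recurrent weights (placeholder)
W_in = rng.normal(size=(N, dim))         # input weights (placeholder)

def step(x, u):
    """One leaky-integrator update: x[t+1] = (1-a)*x[t] + a*tanh(W.x + W_in.u)."""
    return (1 - alpha) * x + alpha * np.tanh(W @ x + W_in @ u)

x = np.zeros(N)           # initial state
u = rng.normal(size=dim)  # one input timestep
x = step(x, u)
print(x.shape)            # (100,)
```

The readout $y[t] = W_{out}^{\intercal} x[t]$ is then just a linear map on this state.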

Echo State Networks with ReservoirPy¶

Models are built out of Nodes:


ESN preparation¶

In [10]:
units = 100               # - number of units
leak_rate = 0.3           # - leaking rate
spectral_radius = 0.95    # - spectral radius
input_scaling = 0.5       # - input scaling (also called input gain)
connectivity = 0.1        # - recurrent weights connectivity probability
input_connectivity = 0.2  # - input weights connectivity probability
regularization = 1e-4     # - L2 regularization coefficient
transient = 100           # - number of warmup steps
seed = 1234               # - used for reproducibility
In [12]:
from reservoirpy.nodes import Reservoir, Ridge

reservoir = Reservoir(units, input_scaling=input_scaling, sr=spectral_radius,
                      lr=leak_rate, rc_connectivity=connectivity,
                      input_connectivity=input_connectivity, seed=seed)

readout   = Ridge(ridge=regularization)
In [13]:
esn = reservoir >> readout
In [14]:
reservoir_fb = reservoir << readout

esn_fb = reservoir_fb >> readout

ESN training¶

Learning is performed offline: the model is updated only once, using all available training data.

In [15]:
esn.fit(X_train1, y_train1, warmup=transient);
In [17]:
plot_readout(readout)

ESN forecast¶

In [18]:
y_pred1 = esn.run(X_test1)
In [19]:
plot_results(y_pred1, y_test1)

Coefficient of determination $R^2$ and normalized root mean square error (NRMSE):

In [20]:
rsquare(y_test1, y_pred1), nrmse(y_test1, y_pred1)
Out[20]:
(0.9966686152092658, 0.011448035696843583)
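For reference, both metrics can be written in a few lines of NumPy. This is a sketch of the standard definitions, not ReservoirPy's own implementation; NRMSE is normalized here by the target's min-max range, which may differ from the library's default:

```python
import numpy as np

def r_square(y_true, y_pred):
    # Coefficient of determination: 1 - SSE / total variance of the target.
    sse = np.sum((y_true - y_pred) ** 2)
    sst = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - sse / sst

def nrmse_minmax(y_true, y_pred):
    # RMSE normalized by the range of the target values.
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    return rmse / (y_true.max() - y_true.min())

# Tiny hypothetical example:
y_true = np.array([[0.0], [1.0], [2.0], [3.0]])
y_pred = np.array([[0.1], [0.9], [2.1], [2.9]])
print(r_square(y_true, y_pred), nrmse_minmax(y_true, y_pred))
```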

Closed-loop reservoir¶

  • Train the ESN on a 1-step-ahead prediction task (learn the flow function $f(x_t) = x_{t+1}$).
  • Use the ESN to predict from its own outputs ("generative" mode).
In [21]:
units = 300               # - number of units
leak_rate = 0.3           # - leaking rate
spectral_radius = 1.25    # - spectral radius
input_scaling = 0.1       # - input scaling (also called input gain)
connectivity = 0.1        # - recurrent weights connectivity probability
input_connectivity = 0.2  # - input weights connectivity probability
regularization = 1e-4     # - L2 regularization coefficient
transient = 100           # - number of warmup steps
seed = 1234               # - used for reproducibility

Forecasting the near future¶

In [23]:
esn = reset_esn()

x, y = to_forecasting(X, forecast=1)
X_train3, y_train3 = x[:2000], y[:2000]
X_test3, y_test3 = x[2000:], y[2000:]

esn = esn.fit(X_train3, y_train3, warmup=transient)

Closed-loop model¶

  • 100 timesteps used as warmup;
  • 500 timesteps generated from the reservoir's own dynamics, without external inputs.
In [24]:
seed_timesteps = 100

warming_inputs = X_test3[:seed_timesteps]

warming_out = esn.run(warming_inputs)  # warmup
In [25]:
nb_generations = 500

X_gen = np.zeros((nb_generations, 3))
y = warming_out[-1]
for t in range(nb_generations):  # generation
    y = esn(y)
    X_gen[t, :] = y
In [26]:
X_t = X_test3[seed_timesteps: nb_generations+seed_timesteps]
plot_generation(X_gen, X_t, warming_out=warming_out, warming_inputs=warming_inputs)
In [27]:
plot_attractors(X_gen, X_t, warming_inputs, warming_out)

Online learning¶

Online learning updates the model incrementally, at each timestep.

In the following, we use the Recursive Least Squares (RLS) algorithm.

In [29]:
from reservoirpy.nodes import RLS

reservoir = Reservoir(units, input_scaling=input_scaling, sr=spectral_radius,
                      lr=leak_rate, rc_connectivity=connectivity,
                      input_connectivity=input_connectivity, seed=seed)

readout   = RLS()  # Recursive Least Squares


esn_online = reservoir >> readout

Step-by-step training¶

In [30]:
outputs_pre = np.zeros(X_train1.shape)
for t, (x, y) in enumerate(zip(X_train1, y_train1)):  # for each timestep:
    prediction = esn_online.train(np.atleast_2d(x), np.atleast_2d(y))
    outputs_pre[t, :] = prediction
In [31]:
plot_results(outputs_pre, y_train1, sample=100)
In [32]:
plot_results(outputs_pre, y_train1, sample=500)

Whole sequence training¶

In [34]:
esn_online.train(X_train1, y_train1)

pred_online = esn_online.run(X_test1)  # Wout is now learned and fixed
In [35]:
plot_results(pred_online, y_test1, sample=500)

Coefficient of determination $R^2$ and NRMSE:

In [36]:
rsquare(y_test1, pred_online), nrmse(y_test1, pred_online)
Out[36]:
(0.9954163172865621, 0.013428448857951327)

Diving into the reservoir¶

In [37]:
units = 300               # - number of units
leak_rate = 0.3           # - leaking rate
spectral_radius = 1.25    # - spectral radius
input_scaling = 0.1       # - input scaling (also called input gain)
connectivity = 0.1        # - recurrent weights connectivity probability
input_connectivity = 0.2  # - input weights connectivity probability
regularization = 1e-4     # - L2 regularization coefficient
transient = 100           # - number of warmup steps
seed = 1234               # - used for reproducibility

1. The spectral radius¶

The spectral radius is the largest absolute eigenvalue of the recurrent weight matrix $W$.
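As a quick check (a sketch, not what ReservoirPy does internally), the spectral radius of a weight matrix can be computed with NumPy, and the matrix rescaled to a target value:

```python
import numpy as np

rng = np.random.default_rng(42)
W = rng.normal(size=(100, 100))  # placeholder recurrent weight matrix

# Spectral radius = largest eigenvalue in absolute value.
sr = np.max(np.abs(np.linalg.eigvals(W)))

# Rescale W so its spectral radius equals a target value.
target = 0.95
W_scaled = W * (target / sr)
print(np.max(np.abs(np.linalg.eigvals(W_scaled))))  # ~0.95
```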

In [38]:
states = []
radii = [0.1, 1.25, 10.0]
for sr in radii:
    reservoir = Reservoir(units, sr=sr, input_scaling=0.001, lr=leak_rate, rc_connectivity=connectivity,
                         input_connectivity=input_connectivity)

    s = reservoir.run(X_test1[:500])
    states.append(s)
In [39]:
units_nb = 20

plt.figure(figsize=(15, 8))
for i, s in enumerate(states):
    plt.subplot(len(radii)*100+10+i+1)
    plt.plot(s[:, :units_nb], alpha=0.6)
    plt.ylabel(f"$sr={radii[i]}$")
plt.xlabel(f"Activations ({units_nb} neurons)")
plt.show()
  • $-$ spectral radius $\rightarrow$ stable dynamics

  • $+$ spectral radius $\rightarrow$ chaotic dynamics

2. The input scaling¶

It is a coefficient applied to $W_{in}$ and can be seen as a gain applied to the inputs.

In [40]:
states = []
scalings = [0.00001, 0.001, 2.0]
for iss in scalings:
    reservoir = Reservoir(units, sr=spectral_radius, input_scaling=iss, lr=leak_rate,
                          rc_connectivity=connectivity, input_connectivity=input_connectivity)

    s = reservoir.run(X_test1[:500])
    states.append(s)
In [41]:
units_nb = 20

plt.figure(figsize=(15, 8))
for i, s in enumerate(states):
    plt.subplot(len(scalings)*100+10+i+1)
    plt.plot(s[:, :units_nb], alpha=0.6)
    plt.ylabel(f"$iss={scalings[i]}$")
plt.xlabel(f"Activations ({units_nb} neurons)")
plt.show()

Average correlation between reservoir states and inputs:

  • $+$ input scaling $\rightarrow$ activity is tied to the input dynamics
  • $-$ input scaling $\rightarrow$ activity evolves freely

Input scaling may be used to adjust relative importance of different inputs.
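One way to do this for multivariate inputs (a manual sketch with hypothetical data; per-feature gains could also be passed through the node's input scaling) is to rescale each input column before feeding it to the reservoir:

```python
import numpy as np

rng = np.random.default_rng(0)
U = rng.normal(size=(500, 3))  # hypothetical 3-feature input series

# Give each feature its own gain: emphasize the first input,
# attenuate the third.
gains = np.array([1.0, 0.5, 0.1])
U_scaled = U * gains           # broadcasts over the time axis

print(U_scaled.shape)          # (500, 3)
```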

3. The leaking rate¶

$$ x(t+1) = \underbrace{\color{red}{(1 - \alpha)} x(t)}_{\text{current state}} + \underbrace{\color{red}\alpha f(u(t+1), x(t))}_{\text{new inputs}} $$

with $\alpha \in [0, 1]$ and:

$$ f(u, x) = \tanh(W_{in} \cdotp u + W \cdotp x) $$

In [42]:
states = []
rates = [0.02, 0.2, 0.9]
for lr in rates:
    reservoir = Reservoir(units, sr=spectral_radius, input_scaling=input_scaling, lr=lr,
                          rc_connectivity=connectivity, input_connectivity=input_connectivity)

    s = reservoir.run(X_test1[:500])
    states.append(s)
In [43]:
units_nb = 20

plt.figure(figsize=(15, 8))
for i, s in enumerate(states):
    plt.subplot(len(rates)*100+10+i+1)
    plt.plot(s[:, :units_nb] + 2*i)
    plt.ylabel(f"$lr={rates[i]}$")
plt.xlabel(f"States ({units_nb} neurons)")
plt.show()
  • $+$ leaking rate $\rightarrow$ low inertia, short activity timescale
  • $-$ leaking rate $\rightarrow$ high inertia, long activity timescale

The leaking rate acts as a proxy for the inverse of the reservoir neurons' time constant.
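Concretely, the leaky integration behaves like a low-pass filter with a time constant of roughly $1/\alpha$ timesteps. The sketch below (plain NumPy, a single hypothetical neuron) shows how, once input stops, the state decays by a factor $(1 - \alpha)$ per step:

```python
import numpy as np

def decay(alpha, steps, x0=1.0):
    # With no input, x[t+1] = (1 - alpha) * x[t]: exponential decay.
    x = x0
    traj = []
    for _ in range(steps):
        x = (1 - alpha) * x
        traj.append(x)
    return np.array(traj)

# A small leak rate keeps memory much longer than a large one.
print(decay(0.02, 10)[-1])  # ~0.817: slow forgetting
print(decay(0.9, 10)[-1])   # ~1e-10: fast forgetting
```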

Use case: falling robot¶

In [45]:
features = ['com_x', 'com_y', 'com_z', 'trunk_pitch', 'trunk_roll', 'left_x', 'left_y',
            'right_x', 'right_y', 'left_ankle_pitch', 'left_ankle_roll', 'left_hip_pitch',
            'left_hip_roll', 'left_hip_yaw', 'left_knee', 'right_ankle_pitch',
            'right_ankle_roll', 'right_hip_pitch', 'right_hip_roll',
            'right_hip_yaw', 'right_knee']

prediction = ['fallen']
force = ['force_orientation', 'force_magnitude']
In [50]:
plot_robot(Y, Y_train, F)

ESN training¶

Using the ESN class, an optimized and distributed implementation of the Echo State Network.

In [52]:
from reservoirpy.nodes import ESN

reservoir = Reservoir(300, lr=0.5, sr=0.99, input_bias=False)
readout   = Ridge(ridge=1e-3)
esn = ESN(reservoir=reservoir, readout=readout, workers=-1)  # parallel computations: on
In [53]:
esn = esn.fit(X_train, y_train)
In [54]:
res = esn.run(X_test)
In [57]:
plot_robot_results(y_test, res)
In [62]:
print("Mean RMSE:", f"{np.mean(scores):.4f}", "±", f"{np.std(scores):.5f}")
print("Mean RMSE (with threshold):", f"{np.mean(filt_scores):.4f}", "±", f"{np.std(filt_scores):.5f}")
Mean RMSE: 0.1693 ± 0.10344
Mean RMSE (with threshold): 0.1443 ± 0.15187
In [81]:
acc = 0.0
for y_pred, y_true in zip(res, y_test):
    true_fall = 1 if np.max(y_true) > 0.8 else 0
    pred_fall = 1 if np.max(y_pred) > 0.8 else 0
    acc += true_fall == pred_fall
print("Accuracy: ", acc / len(y_test))
Accuracy:  0.997229916897507

Use case: anytime decoding of canary songs¶

Dataset can be found on Zenodo: https://zenodo.org/record/4736597


Decoded song units: phrases, which are repetitions of identical syllables.

  • One label per phrase/syllable type, with phrase onset and offset time.
  • One SIL label used to denote silence.
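For frame-wise decoding, annotations like these are typically turned into one-hot targets, one row per spectrogram frame. A minimal sketch, with hypothetical labels and timings rather than the actual dataset format:

```python
import numpy as np

labels = ["SIL", "A", "B"]  # hypothetical phrase types (SIL = silence)
n_frames = 10

# Hypothetical annotation: (label, onset frame, offset frame).
phrases = [("SIL", 0, 3), ("A", 3, 7), ("B", 7, 10)]

Y = np.zeros((n_frames, len(labels)))
for lab, on, off in phrases:
    Y[on:off, labels.index(lab)] = 1.0  # one-hot over frames

print(Y.sum(axis=1))  # every frame has exactly one active label
```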
In [60]:
im = plt.imread("./static/canary_outputs.png")
plt.figure(figsize=(15, 15)); plt.imshow(im); plt.axis('off'); plt.show()

ESN training¶

In [65]:
esn = esn.fit(X_train, y_train)
In [66]:
outputs = esn.run(X_test)
In [61]:
scores  # for each song in the training set
Out[61]:
[0.041898366578752774,
 0.26824916196613413,
 0.056066923530726016,
 ...
 0.08793074557395594,
 0.25738072365190173]
In [69]:
print("Mean accuracy:", f"{np.mean(scores):.4f}", "±", f"{np.std(scores):.5f}")
Mean accuracy: 0.9295 ± 0.02716

Thank you! Now it's your turn to try¶

Nathan Trouvain
Inria, IMN, LaBRI - Bordeaux, France




UCLA - November 14th 2023